import os


def change_list_args(list_args, arg_name, action="change", value=None):
    """Modify a flat ``["--flag", "value", ...]`` style argument list and return it."""
    assert action in ["add", "remove", "change"]

    if action == "add" or action == "remove":
        raise NotImplementedError(f"Action '{action}' is not implemented yet")
    else:
        # * search for the flag and replace the value that follows it
        index_arg = list_args.index(arg_name) if arg_name in list_args else None

        assert index_arg is not None, f"Argument {arg_name} not found in list_args"
        assert index_arg + 1 < len(list_args), f"Argument {arg_name} has no value"

        old_value = list_args[index_arg + 1]
        assert not old_value.startswith("-"), f"Argument {arg_name} has no value"
        list_args[index_arg + 1] = value

    return list_args
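

# Illustrative use of change_list_args: rewriting the value that follows an existing
# flag in a CLI-style argument list (the flag names below are only examples):
#   change_list_args(["--lr", "1e-4", "--seq-length", "2048"], "--lr", value="5e-5")
#   -> ["--lr", "5e-5", "--seq-length", "2048"]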


def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, "{} argument is None".format(arg)


def validate_args(args, defaults={}):
    """
    Edited from megatron.training.arguments.validate_args
    """
    import torch

    # Temporary
    assert args.non_persistent_ckpt_type in ["global", None], "Currently only global checkpoints are supported"

    if args.encoder_tensor_model_parallel_size > 0:
        assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined."
        assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0
        assert args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder."

    if args.encoder_pipeline_model_parallel_size > 0 and args.encoder_tensor_model_parallel_size == 0:
        args.encoder_tensor_model_parallel_size = args.tensor_model_parallel_size

    encoder_model_size = args.encoder_tensor_model_parallel_size * args.encoder_pipeline_model_parallel_size * args.context_parallel_size
    decoder_model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size
    total_model_size = encoder_model_size + decoder_model_size

    # Total model size.
    assert args.world_size % total_model_size == 0, f"world size ({args.world_size}) is not divisible by total_model_size ({encoder_model_size=} + {decoder_model_size=})"

    # Pipeline model parallel size.
    args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size - 1 if args.standalone_embedding_stage else args.pipeline_model_parallel_size

    args.data_parallel_size = args.world_size // total_model_size
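    # e.g. with tensor_model_parallel_size=4, pipeline_model_parallel_size=2,
    # context_parallel_size=1 and no encoder-specific parallelism, total_model_size is
    # 0 + 4*2*1 = 8, so world_size=64 gives data_parallel_size=8.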

    # Checks.
    if args.rank == 0:
        print(
            "using world size: {}, data-parallel size: {}, "
            "context-parallel size: {}, "
            "tensor-model-parallel size: {}, "
            "encoder-tensor-model-parallel size: {}, "
            "pipeline-model-parallel size: {}, "
            "encoder-pipeline-model-parallel size: {}".format(
                args.world_size,
                args.data_parallel_size,
                args.context_parallel_size,
                args.tensor_model_parallel_size,
                args.encoder_tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.encoder_pipeline_model_parallel_size,
            ),
            flush=True,
        )

    # backwards compatibility.
    if args.pipeline_model_parallel_split_rank is not None:
        args.encoder_pipeline_model_parallel_size = args.pipeline_model_parallel_split_rank
        args.pipeline_model_parallel_size -= args.encoder_pipeline_model_parallel_size
        assert args.pipeline_model_parallel_size > 0

    if args.tp_comm_overlap:
        assert args.sequence_parallel, "Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled"

    # Deprecated arguments
    assert args.batch_size is None, "--batch-size argument is no longer " "valid, use --micro-batch-size instead"
    del args.batch_size
    assert args.warmup is None, "--warmup argument is no longer valid, use " "--lr-warmup-fraction instead"
    del args.warmup
    assert args.model_parallel_size is None, "--model-parallel-size is no " "longer valid, use --tensor-model-parallel-size instead"
    del args.model_parallel_size

    if args.checkpoint_activations:
        if args.rank == 0:
            print("--checkpoint-activations is no longer valid, use --recompute-activations, " "or, for more control, --recompute-granularity and --recompute-method.")
        exit()
    del args.checkpoint_activations

    if args.recompute_activations:
        args.recompute_granularity = "selective"
    del args.recompute_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key, None) is not None:
            if args.rank == 0:
                print(
                    "WARNING: overriding default arguments for {key}:{v} with {key}:{v2}".format(key=key, v=defaults[key], v2=getattr(args, key)),
                    flush=True,
                )
        else:
            setattr(args, key, defaults[key])

    if args.data_path is not None and args.split is None:
        legacy_default_split_value = "969, 30, 1"
        if args.rank == 0:
            print("WARNING: Please specify --split when using --data-path. Using legacy default value " f'of "{legacy_default_split_value}"')
        args.split = legacy_default_split_value

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
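    # If --global-batch-size is not given, it defaults to one micro-batch per
    # data-parallel replica, e.g. micro_batch_size=2 with data_parallel_size=16
    # gives global_batch_size=32.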
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print("setting global batch size to {}".format(args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        if args.overlap_p2p_comm:
            assert args.pipeline_model_parallel_size > 1, "when interleaved schedule is used, pipeline-model-parallel size " "should be greater than 1"
        else:
            assert args.pipeline_model_parallel_size > 2, (
                "when interleaved schedule is used and p2p communication overlap is disabled, "
                "pipeline-model-parallel size should be greater than 2 to avoid having multiple "
                "p2p sends and recvs between same 2 ranks per communication batch"
            )
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, "number of layers should be divisible by the pipeline parallel size"
        num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
        assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, "number of layers per pipeline stage must be divisible by the number of layers per virtual pipeline stage"
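        # e.g. num_layers=24 with transformer_pipeline_model_parallel_size=4 gives 6 layers
        # per stage; num_layers_per_virtual_pipeline_stage=2 then yields
        # virtual_pipeline_model_parallel_size=3.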
        args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None
        # Overlap P2P communication is disabled if not using the interleaved schedule.
        args.overlap_p2p_comm = False
        args.align_param_gather = False
        if args.rank == 0:
            print(
                "WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False " "since non-interleaved schedule does not support overlapping p2p communication " "and aligned param AG"
            )

    if args.overlap_param_gather:
        assert args.use_distributed_optimizer, "--overlap-param-gather only supported with distributed optimizer"
        assert args.overlap_grad_reduce, "Must use --overlap-param-gather with --overlap-grad-reduce"
        assert not args.use_legacy_models, "--overlap-param-gather only supported with MCore models"

    if args.overlap_param_gather_with_optimizer_step:
        assert args.use_distributed_optimizer, "--overlap-param-gather-with-optimizer-step only supported with distributed optimizer"
        assert args.overlap_param_gather, "Must use --overlap-param-gather-with-optimizer-step with --overlap-param-gather"
        assert args.virtual_pipeline_model_parallel_size is not None, "--overlap-param-gather-with-optimizer-step only supported with interleaved pipeline parallelism"
        assert not args.use_dist_ckpt, "--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet"

    if args.fp8_param_gather:
        assert args.use_distributed_optimizer, "--fp8-param-gather only supported with distributed optimizer"

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
        # Turn off checking for NaNs in loss and grads if using dynamic loss scaling,
        # where NaNs in grads / loss are signal to the loss scaler.
        if not args.loss_scale:
            args.check_for_nan_in_loss_and_grad = False
            if args.rank == 0:
                print("WARNING: Setting args.check_for_nan_in_loss_and_grad to False since " "dynamic loss scaling is being used")
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print("accumulate and all-reduce gradients in fp32 for " "bfloat16 data type.", flush=True)

    if args.rank == 0:
        print("using {} for parameters ...".format(args.params_dtype), flush=True)

    if args.dataloader_type is None:
        args.dataloader_type = "single"

    # data
    assert args.num_dataset_builder_threads > 0

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.skipped_train_samples = 0
    args.consumed_valid_samples = 0

    # Support for variable sequence lengths across batches/microbatches.
    # set it if the dataloader supports generation of variable sequence lengths
    # across batches/microbatches. Due to additional communication overhead
    # during pipeline parallelism, it should not be set if sequence length
    # is constant during training.
    args.variable_seq_lengths = False

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, "expected iteration-based training"
        assert args.lr_decay_samples is None, "expected iteration-based learning rate decay"
        assert args.lr_warmup_samples == 0, "expected iteration-based learning rate warmup"
        assert args.rampup_batch_size is None, "expected no batch-size rampup for iteration-based training"
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, "can only specify one of lr-warmup-fraction and lr-warmup-iters"

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, "expected sample-based training"
        assert args.lr_decay_iters is None, "expected sample-based learning rate decay"
        assert args.lr_warmup_iters == 0, "expected sample-based learning rate warmup"
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, "can only specify one of lr-warmup-fraction " "and lr-warmup-samples"

    if args.num_layers is not None:
        assert args.encoder_num_layers is None, "cannot have both num-layers and encoder-num-layers specified"
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, "either num-layers or encoder-num-layers should be specified"
        args.num_layers = args.encoder_num_layers

    # Check required arguments.
    required_args = ["num_layers", "hidden_size", "num_attention_heads", "max_position_embeddings"]
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        if args.swiglu:
            # reduce the dimension of the MLP since the projection happens on
            # two linear layers. this keeps the number of parameters in
            # the same ballpark as the counterpart with 4*h size.
            # we keep it a multiple of 64, which means the actual tensor size
            # will be a multiple of 64 / tp_size
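            # e.g. hidden_size=4096 gives int((4 * 4096 * 2 / 3) / 64) * 64 = 10880
            # instead of the non-swiglu default of 16384.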
            args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
        else:
            args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
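        # e.g. hidden_size=4096 with num_attention_heads=32 gives kv_channels=128.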
        args.kv_channels = args.hidden_size // args.num_attention_heads

    if args.seq_length is not None and args.context_parallel_size > 1:
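        # e.g. with context_parallel_size=2, seq_length must be a multiple of 4.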
        assert args.seq_length % (args.context_parallel_size * 2) == 0, "seq-length should be a multiple of 2 * context-parallel-size " "if context-parallel-size > 1."

    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, "lm cross entropy in fp16 is only supported in fp16 mode."
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, "residual connection in fp32 only supported when using fp16 or bf16."

    if args.moe_grouped_gemm:
        assert args.bf16, "Currently GroupedGEMM for MoE only supports bf16 dtype."
        # dc = torch.cuda.get_device_capability()  # NOTE: the CUDA capability check below is commented out
        # assert dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels."

    if args.weight_decay_incr_style == "constant":
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None

    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])
    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print("Persistent fused layer norm kernel is supported from " "pytorch v1.11 (nvidia pytorch container paired with v1.11). " "Defaulting to no_persist_layer_norm=True")

    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, "can distribute " "recomputed activations only across tensor model " "parallel groups"
        assert args.recompute_granularity == "full", "distributed recompute activations is only " "applicable to full recompute granularity"
        assert args.recompute_method is not None, "for distributed recompute activations to work you " "need to use a recompute method "
        assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), (
            "distributed recompute activations are supported for pytorch " "v1.10 and above (Nvidia Pytorch container >= 21.07). Current " "pytorch version is v%s.%s." % (TORCH_MAJOR, TORCH_MINOR)
        )

    if args.recompute_granularity == "selective":
        assert args.recompute_method is None, "recompute method is not yet supported for " "selective recomputing granularity"

    # disable sequence parallelism when tp=1
    # to avoid change in numerics when
    # sequence_parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False

    if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != "1":
        if args.sequence_parallel:
            raise RuntimeError("Using sequence parallelism requires setting the environment variable " "CUDA_DEVICE_MAX_CONNECTIONS to 1")
        if args.async_tensor_model_parallel_allreduce:
            raise RuntimeError("Using async gradient all reduce requires setting the environment " "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

    # Disable bias gelu fusion if we are disabling bias altogether
    if not args.add_bias_linear:
        args.bias_gelu_fusion = False

    # Retro checks.
    if args.retro_add_retriever:

        # Train samples should be auto-loaded.
        assert args.train_samples is not None, "args.train_samples should be auto-loaded from the retro config."

        # Sequence parallelism unsupported.
        assert not args.sequence_parallel, "retro currently does not support sequence parallelism."

        # Pipeline parallelism unsupported.
        assert args.pipeline_model_parallel_size == 1, "retro currently does not support pipeline parallelism."

    if args.decoupled_lr is not None or args.decoupled_min_lr is not None:
        assert not args.use_legacy_models, "--decoupled-lr and --decoupled-min-lr are not supported in legacy models."

    # Legacy RoPE arguments
    if args.use_rotary_position_embeddings:
        args.position_embedding_type = "rope"
    if args.rotary_interleaved and args.apply_rope_fusion:
        raise RuntimeError("--rotary-interleaved does not work with rope_fusion.")
    if args.rotary_interleaved and args.use_legacy_models:
        raise RuntimeError("--rotary-interleaved is not supported in legacy models.")

    # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now
    # don't allow it to keep things simple
    if not args.add_position_embedding and args.position_embedding_type != "rope":
        raise RuntimeError("--no-position-embedding is deprecated, use --position-embedding-type")

    # MoE Spec check
    if args.num_experts == 0:
        args.num_experts = None
    if args.num_experts is not None:
        assert args.spec is None, "Model Spec must be None when using MoEs"

    # Context parallel
    if args.context_parallel_size > 1:
        assert not args.use_legacy_models, "Context parallelism is not supported in legacy models."

    # Expert parallelism check
    if args.expert_model_parallel_size > 1:
        assert args.num_experts is not None, "num_experts must not be None to use expert model parallelism"
        assert args.num_experts % args.expert_model_parallel_size == 0, "Number of experts should be a multiple of expert model parallel_size."
        assert not args.fp16, "Expert parallelism is not supported with fp16 training."

    # Distributed checkpointing checks
    # if args.use_dist_ckpt and args.use_legacy_models:
    #     raise RuntimeError('--use-dist-ckpt is not supported in legacy models.')

    # Data blend checks
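    # e.g. passing both --data-path and --train-data-path trips this check; provide at
    # most one data source (mock data, --data-path, or explicit train/valid/test paths).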
    assert args.mock_data + bool(args.data_path) + any([args.train_data_path, args.valid_data_path, args.test_data_path]) <= 1, "Provide at most one data source: --mock-data, --data-path, or the --train/valid/test-data-path arguments"

    if args.use_tp_pp_dp_mapping:
        assert args.context_parallel_size * args.expert_model_parallel_size <= 1, "context_parallel and expert_model_parallel can't be used with tp-pp-dp mapping."

    # Deterministic mode
    if args.deterministic_mode:
        assert not args.use_flash_attn, "Flash attention can not be used in deterministic mode."
        assert not args.cross_entropy_loss_fusion, "Cross Entropy Fusion is currently not deterministic."

        all_reduce_choices = ["Tree", "Ring", "CollnetDirect", "CollnetChain", "^NVLS"]
        assert os.getenv("NCCL_ALGO", -1) != -1 and os.getenv("NCCL_ALGO") in all_reduce_choices, f"NCCL_ALGO must be one of {all_reduce_choices}."

        torch.use_deterministic_algorithms(True)

    # Update the printed args to reflect that `apply_query_key_layer_scaling` also controls `attention_softmax_in_fp32`
    if args.apply_query_key_layer_scaling:
        args.attention_softmax_in_fp32 = True

    # Checkpointing
    if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
        print("--ckpt-fully-parallel-save flag is deprecated and has no effect." " Use --no-ckpt-fully-parallel-save to disable parallel save.")
    # if (
    #     args.use_dist_ckpt
    #     and not args.ckpt_fully_parallel_save
    #     and args.use_distributed_optimizer
    #     and args.rank == 0
    # ):
    #     print('Warning: With non-parallel ckpt save and DistributedOptimizer,'
    #           ' it will be impossible to resume training with different parallelism.'
    #           ' Consider removing flag --no-ckpt-fully-parallel-save.')
    # if args.use_dist_ckpt_deprecated and args.rank == 0:
    #     print('--use-dist-ckpt is deprecated and has no effect.'
    #           ' Use --ckpt-format to select the checkpoint format.')
    # if args.dist_ckpt_format_deprecated and args.rank == 0:
    #     print('--dist-ckpt-format is deprecated and has no effect.'
    #           ' Use --ckpt-format to select the checkpoint format.')

    # MoE upcycling check
    if args.moe_use_upcycling:
        assert args.save is not None, "When using upcycling, the --save option must be specified."
        if not args.no_load_optim:
            args.no_load_optim = True
            print("Warning: disabling --no-load-optim for upcycling.")
        if not args.no_load_rng:
            args.no_load_rng = True
            print("Warning: disabling --no-load-rng for upcycling.")

    # Print arguments.
    # _print_args("arguments", args)

    return args
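

# Illustrative usage, assuming `args` comes from an argparse-style parser such as
# megatron.training.arguments.parse_args in the upstream code this was edited from
# (names here are only for illustration):
#   args = parse_args()
#   args = validate_args(args, defaults={"seq_length": 2048})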
